{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = ''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['finetune-t5-small-standard-bahasa-cased/checkpoint-190000',\n",
       " 'finetune-t5-small-standard-bahasa-cased/checkpoint-200000',\n",
       " 'finetune-t5-small-standard-bahasa-cased/checkpoint-210000',\n",
       " 'finetune-t5-small-standard-bahasa-cased/checkpoint-220000',\n",
       " 'finetune-t5-small-standard-bahasa-cased/checkpoint-230000']"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from glob import glob\n",
    "\n",
    "checkpoints = sorted(glob('finetune-t5-small-standard-bahasa-cased/checkpoint-*'))\n",
    "checkpoints"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import T5Tokenizer, T5ForConditionalGeneration\n",
    "\n",
    "tokenizer = T5Tokenizer.from_pretrained('mesolitica/t5-small-standard-bahasa-cased')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers.models.bart.modeling_bart import shift_tokens_right\n",
    "from transformers.modeling_outputs import Seq2SeqSequenceClassifierOutput\n",
    "from transformers import T5Model, T5Config\n",
    "from typing import List, Optional, Tuple, Union\n",
    "from torch import nn\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BartClassificationHead(nn.Module):\n",
    "    \"\"\"Head for sentence-level classification tasks.\"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        input_dim: int,\n",
    "        inner_dim: int,\n",
    "        num_classes: int,\n",
    "        pooler_dropout: float,\n",
    "    ):\n",
    "        super().__init__()\n",
    "        self.dense = nn.Linear(input_dim, inner_dim)\n",
    "        self.dropout = nn.Dropout(p=pooler_dropout)\n",
    "        self.out_proj = nn.Linear(inner_dim, num_classes)\n",
    "\n",
    "    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n",
    "        hidden_states = self.dropout(hidden_states)\n",
    "        hidden_states = self.dense(hidden_states)\n",
    "        hidden_states = torch.tanh(hidden_states)\n",
    "        hidden_states = self.dropout(hidden_states)\n",
    "        hidden_states = self.out_proj(hidden_states)\n",
    "        return hidden_states\n",
    "\n",
    "\n",
    "class T5ForSequenceClassification(T5Model):\n",
    "    def __init__(self, config: T5Config, **kwargs):\n",
    "        super().__init__(config, **kwargs)\n",
    "        self.classification_head = BartClassificationHead(\n",
    "            config.d_model,\n",
    "            config.d_model,\n",
    "            config.num_labels,\n",
    "            config.dropout_rate,\n",
    "        )\n",
    "        self._init_weights(self.classification_head.dense)\n",
    "        self._init_weights(self.classification_head.out_proj)\n",
    "\n",
    "    def forward(\n",
    "        self,\n",
    "        input_ids: torch.LongTensor = None,\n",
    "        attention_mask: Optional[torch.Tensor] = None,\n",
    "        decoder_input_ids: Optional[torch.LongTensor] = None,\n",
    "        decoder_attention_mask: Optional[torch.LongTensor] = None,\n",
    "        head_mask: Optional[torch.Tensor] = None,\n",
    "        decoder_head_mask: Optional[torch.Tensor] = None,\n",
    "        cross_attn_head_mask: Optional[torch.Tensor] = None,\n",
    "        encoder_outputs: Optional[List[torch.FloatTensor]] = None,\n",
    "        inputs_embeds: Optional[torch.FloatTensor] = None,\n",
    "        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,\n",
    "        labels: Optional[torch.LongTensor] = None,\n",
    "        use_cache: Optional[bool] = None,\n",
    "        output_attentions: Optional[bool] = None,\n",
    "        output_hidden_states: Optional[bool] = None,\n",
    "        return_dict: Optional[bool] = None,\n",
    "    ):\n",
    "\n",
    "        decoder_input_ids = shift_tokens_right(input_ids,\n",
    "                                               self.config.pad_token_id, self.config.decoder_start_token_id)\n",
    "        outputs = super().forward(\n",
    "            input_ids,\n",
    "            attention_mask=attention_mask,\n",
    "            decoder_input_ids=decoder_input_ids,\n",
    "            decoder_attention_mask=decoder_attention_mask,\n",
    "            head_mask=head_mask,\n",
    "            decoder_head_mask=decoder_head_mask,\n",
    "            cross_attn_head_mask=cross_attn_head_mask,\n",
    "            encoder_outputs=encoder_outputs,\n",
    "            inputs_embeds=inputs_embeds,\n",
    "            decoder_inputs_embeds=decoder_inputs_embeds,\n",
    "            use_cache=use_cache,\n",
    "            output_attentions=output_attentions,\n",
    "            output_hidden_states=output_hidden_states,\n",
    "            return_dict=return_dict,\n",
    "        )\n",
    "        hidden_states = outputs[0]  # last hidden state\n",
    "\n",
    "        eos_mask = input_ids.eq(self.config.eos_token_id)\n",
    "\n",
    "        if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:\n",
    "            raise ValueError(\"All examples must have the same number of <eos> tokens.\")\n",
    "        sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[\n",
    "            :, -1, :\n",
    "        ]\n",
    "        logits = self.classification_head(sentence_representation)\n",
    "\n",
    "        loss = None\n",
    "\n",
    "        if not return_dict:\n",
    "            output = (logits,) + outputs[1:]\n",
    "            return ((loss,) + output) if loss is not None else output\n",
    "\n",
    "        return Seq2SeqSequenceClassifierOutput(\n",
    "            loss=loss,\n",
    "            logits=logits,\n",
    "            past_key_values=outputs.past_key_values,\n",
    "            decoder_hidden_states=outputs.decoder_hidden_states,\n",
    "            decoder_attentions=outputs.decoder_attentions,\n",
    "            cross_attentions=outputs.cross_attentions,\n",
    "            encoder_last_hidden_state=outputs.encoder_last_hidden_state,\n",
    "            encoder_hidden_states=outputs.encoder_hidden_states,\n",
    "            encoder_attentions=outputs.encoder_attentions,\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = T5ForSequenceClassification.from_pretrained(checkpoints[-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "s1 = 'tlong order foodpanda'\n",
    "s2 = 'This text is about makanan'\n",
    "s = f'ayat1: {s1} ayat2: {s2}'\n",
    "strings = []\n",
    "strings.append(s)\n",
    "\n",
    "input_ids = [{'input_ids': tokenizer.encode(s, return_tensors='pt')[0]} for s in strings]\n",
    "padded = tokenizer.pad(input_ids, padding='longest')\n",
    "outputs = model(**padded, return_dict = True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.0659, 0.9341]], grad_fn=<SoftmaxBackward0>)"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "entail_contradiction_logits = outputs.logits[:,[0,2]]\n",
    "probs = entail_contradiction_logits.softmax(dim=1)\n",
    "probs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.9340567], dtype=float32)"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "probs[:,1].detach().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.7766, 0.2234],\n",
       "        [0.4111, 0.5889],\n",
       "        [0.0381, 0.9619],\n",
       "        [0.0099, 0.9901],\n",
       "        [0.0392, 0.9608]], grad_fn=<SoftmaxBackward0>)"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "s1 = 'gov macam bengong, kami nk pilihan raya, gov backdoor, sakai'\n",
    "tags = ['najib razak', 'mahathir', 'kerajaan', 'PRU', 'anarki']\n",
    "strings = []\n",
    "for t in tags:\n",
    "    s = f'ayat1: {s1} ayat2: This text is about {t}'\n",
    "    strings.append(s)\n",
    "input_ids = [{'input_ids': tokenizer.encode(s, return_tensors='pt')[0]} for s in strings]\n",
    "padded = tokenizer.pad(input_ids, padding='longest')\n",
    "outputs = model(**padded, return_dict = True)\n",
    "entail_contradiction_logits = outputs.logits[:,[0,2]]\n",
    "probs = entail_contradiction_logits.softmax(dim=1)\n",
    "probs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.0233, 0.4299, 0.9591, 0.9285, 0.0393, 0.0804, 0.9432, 0.9826],\n",
       "       grad_fn=<SelectBackward0>)"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "s1 = 'kerajaan sebenarnya sangat prihatin dengan rakyat, bagi duit bantuan'\n",
    "tags = ['makan', 'makanan', 'buku', 'kerajaan', 'food delivery',\n",
    "                                       'kerajaan jahat', 'kerajaan prihatin', 'bantuan rakyat']\n",
    "strings = []\n",
    "for t in tags:\n",
    "    s = f'ayat1: {s1} ayat2: ayat ini berkaitan tentang {t}'\n",
    "    strings.append(s)\n",
    "input_ids = [{'input_ids': tokenizer.encode(s, return_tensors='pt')[0]} for s in strings]\n",
    "padded = tokenizer.pad(input_ids, padding='longest')\n",
    "outputs = model(**padded, return_dict = True)\n",
    "entail_contradiction_logits = outputs.logits[:,[0,2]]\n",
    "probs = entail_contradiction_logits.softmax(dim=1)[:, 1]\n",
    "probs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/ubuntu/.local/lib/python3.8/site-packages/huggingface_hub/utils/_deprecation.py:38: FutureWarning: Deprecated positional argument(s) used in 'create_repo': pass token='finetune-mnli-t5-small-standard-bahasa-cased' as keyword args. From version 0.12 passing these as positional arguments will result in an error,\n",
      "  warnings.warn(\n",
      "/home/ubuntu/.local/lib/python3.8/site-packages/huggingface_hub/hf_api.py:102: FutureWarning: `name` and `organization` input arguments are deprecated and will be removed in v0.10. Pass `repo_id` instead.\n",
      "  warnings.warn(\n",
      "/home/ubuntu/.local/lib/python3.8/site-packages/huggingface_hub/hf_api.py:681: FutureWarning: `create_repo` now takes `token` as an optional positional argument. Be sure to adapt your code!\n",
      "  warnings.warn(\n",
      "Cloning https://huggingface.co/mesolitica/finetune-mnli-t5-small-standard-bahasa-cased into local empty directory.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e7524c7b16f142d6a7985a4708f2b0a5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Upload file pytorch_model.bin:   0%|          | 4.00k/232M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "remote: Scanning LFS files for validity, may be slow...        \n",
      "remote: LFS file scan complete.        \n",
      "To https://huggingface.co/mesolitica/finetune-mnli-t5-small-standard-bahasa-cased\n",
      "   692aca7..b4d9003  main -> main\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'https://huggingface.co/mesolitica/finetune-mnli-t5-small-standard-bahasa-cased/commit/b4d9003ca5d8099a0127487a0b5cb44eadc0a4b0'"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.push_to_hub('finetune-mnli-t5-small-standard-bahasa-cased', organization='mesolitica')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0eab3f5b60f44e93aa3659be158b0371",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Upload file spiece.model:   1%|          | 4.00k/784k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "remote: Scanning LFS files for validity, may be slow...        \n",
      "remote: LFS file scan complete.        \n",
      "To https://huggingface.co/mesolitica/finetune-mnli-t5-small-standard-bahasa-cased\n",
      "   b4d9003..f5a5a51  main -> main\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'https://huggingface.co/mesolitica/finetune-mnli-t5-small-standard-bahasa-cased/commit/f5a5a518b2161f51da794adcdfa84f0c211d07c5'"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.push_to_hub('finetune-mnli-t5-small-standard-bahasa-cased', organization='mesolitica')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "10000it [04:01, 41.47it/s]\n"
     ]
    }
   ],
   "source": [
    "import json\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "\n",
    "predicted, actual = [], []\n",
    "count = 0\n",
    "with open('translated-mnli-dev_matched.jsonl') as fopen:\n",
    "    for l in tqdm(fopen):\n",
    "        data = json.loads(l)\n",
    "        if data['gold_label'] == '-':\n",
    "            continue\n",
    "        s = f\"ayat1: {data['translate'][0]} ayat2: {data['translate'][1]}\"\n",
    "        input_ids = tokenizer.encode(s, return_tensors='pt')\n",
    "        logits = model(input_ids = input_ids, return_dict = True).logits[0].detach().numpy()\n",
    "        predicted.append(model.config.id2label[np.argmax(logits)])\n",
    "        actual.append(data['gold_label'])\n",
    "        \n",
    "        count += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "               precision    recall  f1-score   support\n",
      "\n",
      "contradiction    0.82011   0.79010   0.80483      3335\n",
      "   entailment    0.77005   0.82864   0.79827      3233\n",
      "      neutral    0.75184   0.72313   0.73721      3247\n",
      "\n",
      "     accuracy                        0.78064      9815\n",
      "    macro avg    0.78067   0.78063   0.78010      9815\n",
      " weighted avg    0.78103   0.78064   0.78030      9815\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from sklearn import metrics\n",
    "\n",
    "print(\n",
    "    metrics.classification_report(\n",
    "        predicted, actual,\n",
    "        digits = 5\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
